{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## The Warm Up option\n",
    "\n",
    "When we run `.fit`, we can choose to first \"warm up\" each model individually (similar to fine-tunning if the model was pre-trained, but this is a general functionality, i.e. no need of a pretrained model) before the joined training begins. \n",
    "\n",
    "There are 3 warming up routines:\n",
    "\n",
    "1. Warm up all trainable layers at once with a triangular one-cycle learning rate (referred as slanted triangular learning rates in Howard & Ruder 2018)\n",
    "2. Gradual warm up inspired by the work of [Felbo et al., 2017](https://arxiv.org/abs/1708.00524) for fine-tunning\n",
    "3. Gradual warm up inspired by the work of [Howard & Ruder 2018](https://arxiv.org/abs/1801.06146) for fine-tunning\n",
    "\n",
    "Currently warming up is only supported without a fully connected `DeepHead`, i.e. if `deephead=None`. In addition, `Felbo` and `Howard` routines only applied to `DeepText` and `DeepImage` models. The `Wide` and `DeepDense` components can also be warmed up together, but only in an \"all at once\" mode."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Warm up all at once\n",
    "\n",
    "The models will be trained for `warm_epochs` using a triangular one-cycle learning rate (slanted triangular learning rate) ranging from `warm_max_lr/10` to `warm_max_lr` (default is 0.01). 10% of the training steps are used to increase the learning rate which then decreases for the remaining 90%. \n",
    "\n",
    "Here all trainable layers are warmed up.\n",
    "\n",
    "\n",
    "To use it, simply:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "from pytorch_widedeep.preprocessing import WidePreprocessor, DensePreprocessor\n",
    "from pytorch_widedeep.models import Wide, DeepDense, WideDeep\n",
    "from pytorch_widedeep.metrics import Accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>age</th>\n",
       "      <th>workclass</th>\n",
       "      <th>fnlwgt</th>\n",
       "      <th>education</th>\n",
       "      <th>educational_num</th>\n",
       "      <th>marital_status</th>\n",
       "      <th>occupation</th>\n",
       "      <th>relationship</th>\n",
       "      <th>race</th>\n",
       "      <th>gender</th>\n",
       "      <th>capital_gain</th>\n",
       "      <th>capital_loss</th>\n",
       "      <th>hours_per_week</th>\n",
       "      <th>native_country</th>\n",
       "      <th>income_label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>25</td>\n",
       "      <td>Private</td>\n",
       "      <td>226802</td>\n",
       "      <td>11th</td>\n",
       "      <td>7</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>38</td>\n",
       "      <td>Private</td>\n",
       "      <td>89814</td>\n",
       "      <td>HS-grad</td>\n",
       "      <td>9</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Farming-fishing</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>50</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>28</td>\n",
       "      <td>Local-gov</td>\n",
       "      <td>336951</td>\n",
       "      <td>Assoc-acdm</td>\n",
       "      <td>12</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Protective-serv</td>\n",
       "      <td>Husband</td>\n",
       "      <td>White</td>\n",
       "      <td>Male</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>44</td>\n",
       "      <td>Private</td>\n",
       "      <td>160323</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Married-civ-spouse</td>\n",
       "      <td>Machine-op-inspct</td>\n",
       "      <td>Husband</td>\n",
       "      <td>Black</td>\n",
       "      <td>Male</td>\n",
       "      <td>7688</td>\n",
       "      <td>0</td>\n",
       "      <td>40</td>\n",
       "      <td>United-States</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>18</td>\n",
       "      <td>?</td>\n",
       "      <td>103497</td>\n",
       "      <td>Some-college</td>\n",
       "      <td>10</td>\n",
       "      <td>Never-married</td>\n",
       "      <td>?</td>\n",
       "      <td>Own-child</td>\n",
       "      <td>White</td>\n",
       "      <td>Female</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30</td>\n",
       "      <td>United-States</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   age  workclass  fnlwgt     education  educational_num      marital_status  \\\n",
       "0   25    Private  226802          11th                7       Never-married   \n",
       "1   38    Private   89814       HS-grad                9  Married-civ-spouse   \n",
       "2   28  Local-gov  336951    Assoc-acdm               12  Married-civ-spouse   \n",
       "3   44    Private  160323  Some-college               10  Married-civ-spouse   \n",
       "4   18          ?  103497  Some-college               10       Never-married   \n",
       "\n",
       "          occupation relationship   race  gender  capital_gain  capital_loss  \\\n",
       "0  Machine-op-inspct    Own-child  Black    Male             0             0   \n",
       "1    Farming-fishing      Husband  White    Male             0             0   \n",
       "2    Protective-serv      Husband  White    Male             0             0   \n",
       "3  Machine-op-inspct      Husband  Black    Male          7688             0   \n",
       "4                  ?    Own-child  White  Female             0             0   \n",
       "\n",
       "   hours_per_week native_country  income_label  \n",
       "0              40  United-States             0  \n",
       "1              50  United-States             0  \n",
       "2              40  United-States             1  \n",
       "3              40  United-States             1  \n",
       "4              30  United-States             0  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('data/adult/adult.csv.zip')\n",
    "# For convenience, we'll replace '-' with '_'\n",
    "df.columns = [c.replace(\"-\", \"_\") for c in df.columns]\n",
    "# binary target\n",
    "df['income_label'] = (df[\"income\"].apply(lambda x: \">50K\" in x)).astype(int)\n",
    "df.drop('income', axis=1, inplace=True)\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide_cols = ['education', 'relationship','workclass','occupation','native_country','gender']\n",
    "crossed_cols = [('education', 'occupation'), ('native_country', 'occupation')]\n",
    "cat_embed_cols = [('education',16), ('relationship',8), ('workclass',16), ('occupation',16),('native_country',16)]\n",
    "continuous_cols = [\"age\",\"hours_per_week\"]\n",
    "target_col = 'income_label'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TARGET\n",
    "target = df[target_col].values\n",
    "\n",
    "# WIDE\n",
    "preprocess_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n",
    "X_wide = preprocess_wide.fit_transform(df)\n",
    "\n",
    "# DEEP\n",
    "preprocess_deep = DensePreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
    "X_deep = preprocess_deep.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "deepdense = DeepDense(hidden_layers=[64,32], \n",
    "                      deep_column_idx=preprocess_deep.deep_column_idx,\n",
    "                      embed_input=preprocess_deep.embeddings_input,\n",
    "                      continuous_cols=continuous_cols)\n",
    "model = WideDeep(wide=wide, deepdense=deepdense)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.compile(method='binary', metrics=[Accuracy])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Up until here is identical to the code in notebook `03_Binary_Classification_with_Defaults`. Now you can warm up via the warm up parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/153 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up wide for 5 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 153/153 [00:01<00:00, 92.77it/s, loss=0.598, metrics={'acc': 0.697438128631024}] \n",
      "epoch 2: 100%|██████████| 153/153 [00:01<00:00, 106.70it/s, loss=0.394, metrics={'acc': 0.758272976223991}] \n",
      "epoch 3: 100%|██████████| 153/153 [00:01<00:00, 101.83it/s, loss=0.371, metrics={'acc': 0.7821172335542873}]\n",
      "epoch 4: 100%|██████████| 153/153 [00:01<00:00, 107.60it/s, loss=0.365, metrics={'acc': 0.7943976659074041}]\n",
      "epoch 5: 100%|██████████| 153/153 [00:01<00:00, 105.79it/s, loss=0.363, metrics={'acc': 0.8018273488086403}]\n",
      "  0%|          | 0/153 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up deepdense for 5 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 153/153 [00:02<00:00, 73.75it/s, loss=0.398, metrics={'acc': 0.802813537054573}] \n",
      "epoch 2: 100%|██████████| 153/153 [00:02<00:00, 73.16it/s, loss=0.349, metrics={'acc': 0.8077444782842372}]\n",
      "epoch 3: 100%|██████████| 153/153 [00:02<00:00, 72.68it/s, loss=0.343, metrics={'acc': 0.811737005093031}] \n",
      "epoch 4: 100%|██████████| 153/153 [00:02<00:00, 75.04it/s, loss=0.338, metrics={'acc': 0.8150868602075317}]\n",
      "epoch 5: 100%|██████████| 153/153 [00:02<00:00, 74.24it/s, loss=0.335, metrics={'acc': 0.8180226755048243}]\n",
      "  0%|          | 0/153 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 153/153 [00:02<00:00, 74.96it/s, loss=0.361, metrics={'acc': 0.8315}]\n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 136.56it/s, loss=0.365, metrics={'acc': 0.8318}]\n",
      "epoch 2: 100%|██████████| 153/153 [00:02<00:00, 75.09it/s, loss=0.361, metrics={'acc': 0.8317}]\n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 129.90it/s, loss=0.365, metrics={'acc': 0.8321}]\n",
      "epoch 3: 100%|██████████| 153/153 [00:02<00:00, 73.24it/s, loss=0.36, metrics={'acc': 0.8317}] \n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 131.37it/s, loss=0.365, metrics={'acc': 0.8321}]\n",
      "epoch 4: 100%|██████████| 153/153 [00:02<00:00, 72.38it/s, loss=0.36, metrics={'acc': 0.832}]  \n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 130.72it/s, loss=0.365, metrics={'acc': 0.8324}]\n",
      "epoch 5: 100%|██████████| 153/153 [00:02<00:00, 72.24it/s, loss=0.359, metrics={'acc': 0.8322}]\n",
      "valid: 100%|██████████| 39/39 [00:00<00:00, 130.20it/s, loss=0.364, metrics={'acc': 0.8326}]\n"
     ]
    }
   ],
   "source": [
    "model.fit(X_wide=X_wide, X_deep=X_deep, target=target, n_epochs=5, batch_size=256, val_split=0.2, \n",
    "          warm_up=True, warm_epochs=5, warm_max_lr=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Warm up Gradually: The \"felbo\"  and the \"howard\" routines\n",
    "\n",
    "The Felbo routine can be illustrated as follows:\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img width=\"600\" src=\"figures/felbo_routine.png\">\n",
    "</p>\n",
    "\n",
    "**Figure 1.** The figure can be described as follows: warm up (or train) the last layer for one epoch using a one cycle triangular learning rate. Then warm up the next deeper layer for one epoch, with a learning rate that is a factor of 2.5 lower than the previous learning rate (the 2.5 factor is fixed) while freezing the already warmed up layer(s). Repeat untill all individual layers are warmed. Then warm one last epoch with all warmed layers trainable. The vanishing color gradient in the figure attempts to illustrate the decreasing learning rate. \n",
    "\n",
    "Note that this is not identical to the Fine-Tunning routine described in Felbo et al, 2017, this is why I used the word 'inspired'.\n",
    "\n",
    "The Howard routine can be illustrated as follows:\n",
    "\n",
    "<p align=\"center\">\n",
    "  <img width=\"600\" src=\"figures/howard_routine.png\">\n",
    "</p>\n",
    "\n",
    "**Figure 2.** The figure can be described as follows: warm up (or train) the last layer for one epoch using a one cycle triangular learning rate. Then warm up the next deeper layer for one epoch, with a learning rate that is a factor of 2.5 lower than the previous learning rate (the 2.5 factor is fixed) while keeping the already warmed up layer(s) trainable. Repeat. The vanishing color gradient in the figure attempts to illustrate the decreasing learning rate. \n",
    "\n",
    "Note that I write \"*warm up (or train) the last layer for one epoch [...]*\". However, in practice the user will have to specify the order of the layers to be warmed up. This is another reason why I wrote that the warm up routines I have implemented are **inspired** by the work of Felbo and Howard and not identical to their implemenations.\n",
    "\n",
    "The `felbo` and `howard` routines can be accessed with via the `warm` parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_widedeep.preprocessing import TextPreprocessor, ImagePreprocessor\n",
    "from pytorch_widedeep.models import DeepText, DeepImage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "df = pd.read_csv('data/airbnb/airbnb_sample.csv')\n",
    "# There are a number of columns that are already binary. Therefore, no need to one hot encode them\n",
    "crossed_cols = (['property_type', 'room_type'],)\n",
    "already_dummies = [c for c in df.columns if 'amenity' in c] + ['has_house_rules']\n",
    "wide_cols = ['is_location_exact', 'property_type', 'room_type', 'host_gender',\n",
    "'instant_bookable'] + already_dummies\n",
    "cat_embed_cols = [(c, 16) for c in df.columns if 'catg' in c] + \\\n",
    "    [('neighbourhood_cleansed', 64), ('cancellation_policy', 16)]\n",
    "continuous_cols = ['latitude', 'longitude', 'security_deposit', 'extra_people']\n",
    "# it does not make sense to standarised Latitude and Longitude\n",
    "already_standard = ['latitude', 'longitude']\n",
    "# text and image colnames\n",
    "text_col = 'description'\n",
    "img_col = 'id'\n",
    "# path to pretrained word embeddings and the images\n",
    "word_vectors_path = 'data/glove.6B/glove.6B.100d.txt'\n",
    "img_path = 'data/airbnb/property_picture'\n",
    "# target\n",
    "target_col = 'yield'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The vocabulary contains 2192 tokens\n",
      "Indexing word vectors...\n",
      "Loaded 400000 word vectors\n",
      "Preparing embeddings matrix...\n",
      "2175 words in the vocabulary had data/glove.6B/glove.6B.100d.txt vectors and appear more than 5 times\n",
      "Reading Images from data/airbnb/property_picture\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "  8%|▊         | 84/1001 [00:00<00:02, 418.96it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Resizing\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1001/1001 [00:02<00:00, 409.78it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Computing normalisation metrics\n"
     ]
    }
   ],
   "source": [
    "target = df[target_col].values\n",
    "\n",
    "prepare_wide = WidePreprocessor(wide_cols=wide_cols, crossed_cols=crossed_cols)\n",
    "X_wide = prepare_wide.fit_transform(df)\n",
    "\n",
    "prepare_deep = DensePreprocessor(embed_cols=cat_embed_cols, continuous_cols=continuous_cols)\n",
    "X_deep = prepare_deep.fit_transform(df)\n",
    "\n",
    "text_processor = TextPreprocessor(word_vectors_path=word_vectors_path, text_col=text_col)\n",
    "X_text = text_processor.fit_transform(df)\n",
    "\n",
    "image_processor = ImagePreprocessor(img_col=img_col, img_path=img_path)\n",
    "X_images = image_processor.fit_transform(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "wide = Wide(wide_dim=np.unique(X_wide).shape[0], pred_dim=1)\n",
    "deepdense = DeepDense( hidden_layers=[64,32], dropout=[0.2,0.2],\n",
    "                      deep_column_idx=prepare_deep.deep_column_idx,\n",
    "                      embed_input=prepare_deep.embeddings_input,\n",
    "                      continuous_cols=continuous_cols)\n",
    "deeptext = DeepText(vocab_size=len(text_processor.vocab.itos),\n",
    "                    hidden_dim=64, n_layers=3, rnn_dropout=0.5,\n",
    "                    embedding_matrix=text_processor.embedding_matrix)\n",
    "deepimage = DeepImage(pretrained=True, head_layers=None)\n",
    "model = WideDeep(wide=wide, deepdense=deepdense, deeptext=deeptext, deepimage=deepimage)\n",
    "model.compile(method='regression')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "let's have a look to the fit method"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "?model.fit"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As you will see the `warm` args are: \n",
    "\n",
    "```python\n",
    "    warm_up: bool = False,\n",
    "    warm_epochs: int = 4,\n",
    "    warm_max_lr: float = 0.01,\n",
    "    warm_deeptext_gradual: bool = False,\n",
    "    warm_deeptext_max_lr: float = 0.01,\n",
    "    warm_deeptext_layers: Optional[List[nn.Module]] = None,\n",
    "    warm_deepimage_gradual: bool = False,\n",
    "    warm_deepimage_max_lr: float = 0.01,\n",
    "    warm_deepimage_layers: Optional[List[nn.Module]] = None,\n",
    "    warm_routine: str = \"howard\",\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We need to explicitly indicate 1) that we want to warm up, 2) that we want `DeepText` and/or `DeepImage` to warm up gradually 3) in that case, the warm up routine and 4) the layers we want to warm up. \n",
    "\n",
    "For example, let's have a look to the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "WideDeep(\n",
       "  (wide): Wide(\n",
       "    (wide_linear): Embedding(357, 1, padding_idx=0)\n",
       "  )\n",
       "  (deepdense): Sequential(\n",
       "    (0): DeepDense(\n",
       "      (embed_layers): ModuleDict(\n",
       "        (emb_layer_accommodates_catg): Embedding(4, 16)\n",
       "        (emb_layer_bathrooms_catg): Embedding(4, 16)\n",
       "        (emb_layer_bedrooms_catg): Embedding(5, 16)\n",
       "        (emb_layer_beds_catg): Embedding(5, 16)\n",
       "        (emb_layer_cancellation_policy): Embedding(6, 16)\n",
       "        (emb_layer_guests_included_catg): Embedding(4, 16)\n",
       "        (emb_layer_host_listings_count_catg): Embedding(5, 16)\n",
       "        (emb_layer_minimum_nights_catg): Embedding(4, 16)\n",
       "        (emb_layer_neighbourhood_cleansed): Embedding(33, 64)\n",
       "      )\n",
       "      (embed_dropout): Dropout(p=0.0, inplace=False)\n",
       "      (dense): Sequential(\n",
       "        (dense_layer_0): Sequential(\n",
       "          (0): Linear(in_features=196, out_features=64, bias=True)\n",
       "          (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "          (2): Dropout(p=0.2, inplace=False)\n",
       "        )\n",
       "        (dense_layer_1): Sequential(\n",
       "          (0): Linear(in_features=64, out_features=32, bias=True)\n",
       "          (1): LeakyReLU(negative_slope=0.01, inplace=True)\n",
       "          (2): Dropout(p=0.2, inplace=False)\n",
       "        )\n",
       "      )\n",
       "    )\n",
       "    (1): Linear(in_features=32, out_features=1, bias=True)\n",
       "  )\n",
       "  (deeptext): Sequential(\n",
       "    (0): DeepText(\n",
       "      (word_embed): Embedding(2192, 100, padding_idx=1)\n",
       "      (rnn): LSTM(100, 64, num_layers=3, batch_first=True, dropout=0.5)\n",
       "    )\n",
       "    (1): Linear(in_features=64, out_features=1, bias=True)\n",
       "  )\n",
       "  (deepimage): Sequential(\n",
       "    (0): DeepImage(\n",
       "      (backbone): Sequential(\n",
       "        (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "        (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (2): ReLU(inplace=True)\n",
       "        (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "        (4): Sequential(\n",
       "          (0): BasicBlock(\n",
       "            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (relu): ReLU(inplace=True)\n",
       "            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "          (1): BasicBlock(\n",
       "            (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (relu): ReLU(inplace=True)\n",
       "            (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (5): Sequential(\n",
       "          (0): BasicBlock(\n",
       "            (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (relu): ReLU(inplace=True)\n",
       "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (downsample): Sequential(\n",
       "              (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "              (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            )\n",
       "          )\n",
       "          (1): BasicBlock(\n",
       "            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (relu): ReLU(inplace=True)\n",
       "            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (6): Sequential(\n",
       "          (0): BasicBlock(\n",
       "            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (relu): ReLU(inplace=True)\n",
       "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (downsample): Sequential(\n",
       "              (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "              (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            )\n",
       "          )\n",
       "          (1): BasicBlock(\n",
       "            (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (relu): ReLU(inplace=True)\n",
       "            (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (7): Sequential(\n",
       "          (0): BasicBlock(\n",
       "            (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "            (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (relu): ReLU(inplace=True)\n",
       "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (downsample): Sequential(\n",
       "              (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "              (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            )\n",
       "          )\n",
       "          (1): BasicBlock(\n",
       "            (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "            (relu): ReLU(inplace=True)\n",
       "            (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "            (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "          )\n",
       "        )\n",
       "        (8): AdaptiveAvgPool2d(output_size=(1, 1))\n",
       "      )\n",
       "    )\n",
       "    (1): Linear(in_features=512, out_features=1, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can see that the `DeepImage` model is comprised by a `Sequential` model that is a ResNet `backbone` and a `Linear` Layer. I want to warm up the layers in the ResNet `backbone`, apart from the first sequence `[Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d]`, and the `Linear` layer, so let's access them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[Sequential(\n",
       "   (0): BasicBlock(\n",
       "     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace=True)\n",
       "     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       "   (1): BasicBlock(\n",
       "     (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace=True)\n",
       "     (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       " ),\n",
       " Sequential(\n",
       "   (0): BasicBlock(\n",
       "     (conv1): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace=True)\n",
       "     (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (downsample): Sequential(\n",
       "       (0): Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "       (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     )\n",
       "   )\n",
       "   (1): BasicBlock(\n",
       "     (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace=True)\n",
       "     (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       " ),\n",
       " Sequential(\n",
       "   (0): BasicBlock(\n",
       "     (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace=True)\n",
       "     (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (downsample): Sequential(\n",
       "       (0): Conv2d(128, 256, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "       (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     )\n",
       "   )\n",
       "   (1): BasicBlock(\n",
       "     (conv1): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace=True)\n",
       "     (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       " ),\n",
       " Sequential(\n",
       "   (0): BasicBlock(\n",
       "     (conv1): Conv2d(256, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace=True)\n",
       "     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (downsample): Sequential(\n",
       "       (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "       (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     )\n",
       "   )\n",
       "   (1): BasicBlock(\n",
       "     (conv1): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "     (relu): ReLU(inplace=True)\n",
       "     (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "     (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "   )\n",
       " ),\n",
       " Linear(in_features=512, out_features=1, bias=True)]"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
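   ],
   "source": [
    "# the first child of `deepimage` is the Sequential that holds the ResNet backbone\n",
    "first_child = list(model.deepimage.children())[0]\n",
    "# backbone children 4 to 7 are the residual blocks (the initial Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d is skipped);\n",
    "# the second child of `deepimage` is the final Linear layer\n",
    "img_layers = list(first_child.backbone.children())[4:8] + [list(model.deepimage.children())[1]]\n",
    "img_layers"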
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The layers need to be passed in the order in which we want them to be warmed up. In the future I might infer this automatically within the `_warmup.py` submodule, but for now the user needs to specify the warm up order. In this case it is pretty straightforward: we reverse the list so that the `Linear` layer comes first."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "warm_img_layers = img_layers[::-1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up wide for 1 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 25/25 [00:00<00:00, 34.36it/s, loss=1.64e+4]\n",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up deepdense for 1 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 25/25 [00:00<00:00, 46.93it/s, loss=1.37e+4]\n",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up deeptext for 1 epochs\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 25/25 [00:04<00:00,  5.53it/s, loss=1.74e+4]\n",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up deepimage, layer 1 of 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 25/25 [01:05<00:00,  2.63s/it, loss=1.41e+4]\n",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up deepimage, layer 2 of 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 25/25 [01:29<00:00,  3.57s/it, loss=1.17e+4]\n",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up deepimage, layer 3 of 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 25/25 [01:51<00:00,  4.46s/it, loss=1.11e+4]\n",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up deepimage, layer 4 of 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 25/25 [02:17<00:00,  5.48s/it, loss=1.11e+4]\n",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Warming up deepimage, layer 5 of 5\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 25/25 [02:50<00:00,  6.83s/it, loss=1.08e+4]\n",
      "  0%|          | 0/25 [00:00<?, ?it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Training\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "epoch 1: 100%|██████████| 25/25 [01:10<00:00,  2.83s/it, loss=1.45e+4]\n",
      "valid: 100%|██████████| 7/7 [00:14<00:00,  2.00s/it, loss=1.19e+4]\n"
     ]
    }
   ],
   "source": [
    "model.fit(X_wide=X_wide, X_deep=X_deep, X_text=X_text, X_img=X_images, target=target, n_epochs=1, \n",
    "          batch_size=32, val_split=0.2, warm_up=True, warm_epochs=1, warm_deepimage_gradual=True, \n",
    "          warm_deepimage_layers=warm_img_layers, warm_deepimage_max_lr=0.01, warm_routine='howard')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To use the `felbo` routine instead, set the `warm_routine` parameter to `'felbo'`."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
